{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train 25000 samples | Test 25000 samples\n",
      "Epoch 1/2 | Step 0/781 | train_loss: 0.7418 | train_acc: 0.5312 | lr: 0.0050\n",
      "Epoch 1/2 | Step 50/781 | train_loss: 0.7300 | train_acc: 0.5000 | lr: 0.0047\n",
      "Epoch 1/2 | Step 100/781 | train_loss: 0.6462 | train_acc: 0.5625 | lr: 0.0045\n",
      "Epoch 1/2 | Step 150/781 | train_loss: 0.4495 | train_acc: 0.7500 | lr: 0.0043\n",
      "Epoch 1/2 | Step 200/781 | train_loss: 0.5295 | train_acc: 0.8125 | lr: 0.0041\n",
      "Epoch 1/2 | Step 250/781 | train_loss: 0.3412 | train_acc: 0.8438 | lr: 0.0039\n",
      "Epoch 1/2 | Step 300/781 | train_loss: 0.3876 | train_acc: 0.8125 | lr: 0.0037\n",
      "Epoch 1/2 | Step 350/781 | train_loss: 0.3251 | train_acc: 0.8438 | lr: 0.0035\n",
      "Epoch 1/2 | Step 400/781 | train_loss: 0.5327 | train_acc: 0.7812 | lr: 0.0033\n",
      "Epoch 1/2 | Step 450/781 | train_loss: 0.1768 | train_acc: 0.9375 | lr: 0.0031\n",
      "Epoch 1/2 | Step 500/781 | train_loss: 0.4532 | train_acc: 0.8438 | lr: 0.0030\n",
      "Epoch 1/2 | Step 550/781 | train_loss: 0.4453 | train_acc: 0.7812 | lr: 0.0028\n",
      "Epoch 1/2 | Step 600/781 | train_loss: 0.3405 | train_acc: 0.8438 | lr: 0.0027\n",
      "Epoch 1/2 | Step 650/781 | train_loss: 0.1562 | train_acc: 0.9062 | lr: 0.0026\n",
      "Epoch 1/2 | Step 700/781 | train_loss: 0.3116 | train_acc: 0.8438 | lr: 0.0024\n",
      "Epoch 1/2 | Step 750/781 | train_loss: 0.2540 | train_acc: 0.9688 | lr: 0.0023\n",
      "Epoch 1/2 | train_loss: 0.8004 | train_acc: 0.6250 | test_loss: 0.2542 | test_acc: 0.8936 | lr: 0.0022\n",
      "Epoch 2/2 | Step 0/781 | train_loss: 0.1875 | train_acc: 0.9062 | lr: 0.0022\n",
      "Epoch 2/2 | Step 50/781 | train_loss: 0.1407 | train_acc: 0.9688 | lr: 0.0021\n",
      "Epoch 2/2 | Step 100/781 | train_loss: 0.0812 | train_acc: 0.9688 | lr: 0.0020\n",
      "Epoch 2/2 | Step 150/781 | train_loss: 0.0977 | train_acc: 0.9688 | lr: 0.0019\n",
      "Epoch 2/2 | Step 200/781 | train_loss: 0.2572 | train_acc: 0.8750 | lr: 0.0018\n",
      "Epoch 2/2 | Step 250/781 | train_loss: 0.1185 | train_acc: 0.9688 | lr: 0.0017\n",
      "Epoch 2/2 | Step 300/781 | train_loss: 0.0920 | train_acc: 0.9688 | lr: 0.0016\n",
      "Epoch 2/2 | Step 350/781 | train_loss: 0.2204 | train_acc: 0.9375 | lr: 0.0016\n",
      "Epoch 2/2 | Step 400/781 | train_loss: 0.2689 | train_acc: 0.8750 | lr: 0.0015\n",
      "Epoch 2/2 | Step 450/781 | train_loss: 0.0805 | train_acc: 0.9375 | lr: 0.0014\n",
      "Epoch 2/2 | Step 500/781 | train_loss: 0.0747 | train_acc: 0.9688 | lr: 0.0013\n",
      "Epoch 2/2 | Step 550/781 | train_loss: 0.1461 | train_acc: 0.9062 | lr: 0.0013\n",
      "Epoch 2/2 | Step 600/781 | train_loss: 0.2100 | train_acc: 0.9375 | lr: 0.0012\n",
      "Epoch 2/2 | Step 650/781 | train_loss: 0.1159 | train_acc: 0.9375 | lr: 0.0011\n",
      "Epoch 2/2 | Step 700/781 | train_loss: 0.2800 | train_acc: 0.9062 | lr: 0.0011\n",
      "Epoch 2/2 | Step 750/781 | train_loss: 0.0924 | train_acc: 0.9375 | lr: 0.0010\n",
      "Epoch 2/2 | train_loss: 0.1618 | train_acc: 0.8750 | test_loss: 0.2475 | test_acc: 0.9002 | lr: 0.0010\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.89      0.91      0.90     12500\n",
      "          1       0.91      0.89      0.90     12500\n",
      "\n",
      "avg / total       0.90      0.90      0.90     25000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "from rnn_text_clf import RNNTextClassifier\n",
    "from sklearn.metrics import classification_report\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Vocabulary cap: passed to imdb.load_data(num_words=...) and to RNNTextClassifier.\n",
    "vocab_size = 20000\n",
    "# Mini-batch size handed to clf.fit.\n",
    "batch_size = 32\n",
    "\n",
    "\n",
    "def sort_by_len(x, y):\n",
    "    \"\"\"Reorder samples so their sequences run from shortest to longest.\n",
    "\n",
    "    Placing sequences of similar length next to each other means a batch of\n",
    "    consecutive samples needs little padding. The sort is stable, so samples\n",
    "    of equal length keep their original relative order.\n",
    "\n",
    "    Args:\n",
    "        x: sequence of token-id sequences (a numpy array, as in this\n",
    "            notebook, or a plain list).\n",
    "        y: labels aligned one-to-one with x.\n",
    "\n",
    "    Returns:\n",
    "        (x_sorted, y_sorted). numpy arrays are reordered with fancy indexing\n",
    "        and stay numpy arrays; any other sequence comes back as a list.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if x and y differ in length (fancy indexing would\n",
    "            otherwise silently drop surplus labels or fail obscurely).\n",
    "    \"\"\"\n",
    "    if len(x) != len(y):\n",
    "        raise ValueError('x and y must have the same length, got %d and %d'\n",
    "                         % (len(x), len(y)))\n",
    "    idx = sorted(range(len(x)), key=lambda i: len(x[i]))\n",
    "\n",
    "    def _take(seq):\n",
    "        # Fancy indexing for numpy arrays; plain lists/tuples can't be\n",
    "        # indexed by a list of positions, so gather element-wise instead.\n",
    "        if isinstance(seq, np.ndarray):\n",
    "            return seq[idx]\n",
    "        return [seq[i] for i in idx]\n",
    "\n",
    "    return _take(x), _take(y)\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # IMDB reviews as word-id sequences, restricted to the vocab_size most frequent words.\n",
    "    train_data, test_data = tf.keras.datasets.imdb.load_data(num_words=vocab_size)\n",
    "    X_train, y_train = sort_by_len(*train_data)\n",
    "    X_test, y_test = sort_by_len(*test_data)\n",
    "\n",
    "    # Second argument is presumably the number of classes (labels are 0/1) -- confirm in rnn_text_clf.\n",
    "    clf = RNNTextClassifier(vocab_size, 2)\n",
    "    log = clf.fit(\n",
    "        X_train, y_train,\n",
    "        val_data=(X_test, y_test),\n",
    "        n_epoch=2,\n",
    "        batch_size=batch_size,\n",
    "        keep_prob=0.8,\n",
    "        en_exp_decay=True,\n",
    "    )\n",
    "\n",
    "    # Per-class precision / recall / F1 on the test split.\n",
    "    y_pred = clf.predict(X_test)\n",
    "    print(classification_report(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
